#model.py
import torchvision.models as models
import torch.nn as nn
def get_model(num_classes, pretrained=True):
    """Build a ResNet-50 whose final fully connected layer has ``num_classes`` outputs.

    Args:
        num_classes: Number of output features of the replacement ``fc`` layer.
        pretrained: If True (the default, matching the original behaviour),
            load the ImageNet-1K (V1) weights for the backbone; if False, the
            backbone is randomly initialised and no weights are downloaded.

    Returns:
        A ``torchvision`` ResNet-50 module. Its ``fc`` layer is a freshly
        initialised ``nn.Linear(in_features, num_classes)``; the pretrained
        classifier head, if any, is discarded.
    """
    # torchvision >= 0.13 replaced ``pretrained=`` with the ``weights=`` enum
    # and warns on the old keyword. IMAGENET1K_V1 is exactly what
    # ``pretrained=True`` used to load, so results are unchanged.
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is not None:
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        model = models.resnet50(weights=weights)
    else:
        # Older torchvision without the weights enum: use the legacy keyword.
        model = models.resnet50(pretrained=pretrained)

    # Replace the classifier head so the output size matches the task,
    # keeping the backbone's feature width as the input size.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

